What do the pdf updates look like?¶

Exploring how the input distribution changes under Fisher information updates.

In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.core.pylabtools import figsize

import seaborn as sns
import plotly.express as px

import numpy as np
import pandas as pd
import polars as pl

import statsmodels.formula.api as smf
import statsmodels.api as sm

In [2]:
import torch

from discriminationAnalysis import Fisher_smooth_fits
from basicModel import EstimateAngle
from adaptableModel import AdaptableEstimator, AngleDistribution

from adapt_fit_loop import moving_average

pdf update for the concentrated case¶

In [3]:
import glob

fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
plt.title('Fisher Information: trained networks')

ex4_conc_dir = 'trainedParameters/Exp4_conc/'

FIcurves_conc = []

for rep in range(6):
    trained_ckpt =  glob.glob(ex4_conc_dir + f'rep{rep}/epoch*')[0]
    
    model = EstimateAngle.load_from_checkpoint(trained_ckpt)

    fi = Fisher_smooth_fits(model, 0., np.pi, N_cov=500, Samp_cov=500)
    FIcurves_conc.append(fi)
    
    plt.plot(np.linspace(0, 2*np.pi, 500), fi)
No description has been provided for this image
In [4]:
np.array(FIcurves_conc).min(1)
Out[4]:
array([14.69482853,  9.7832375 ,  0.17993534, 17.95820599, 25.27543399,
       18.58125288])
In [ ]:
fisher_curves = np.array(FIcurves_conc)

smoothed_mean_fisher = moving_average(np.mean(fisher_curves, axis=0))
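`moving_average` is imported from `adapt_fit_loop` and its definition isn't shown in this notebook. A minimal circular-boxcar stand-in (an assumption — the real window size and edge handling may differ; `moving_average_circular` is a made-up name) could look like:

```python
import numpy as np

def moving_average_circular(y, window=25):
    """Centered boxcar smoothing of a periodic signal.

    Hypothetical stand-in for adapt_fit_loop.moving_average; pads
    circularly so the smoothing respects the periodic angle domain.
    """
    kernel = np.ones(window) / window
    padded = np.concatenate([y[-window:], y, y[:window]])  # circular padding
    smoothed = np.convolve(padded, kernel, mode='same')
    return smoothed[window:-window]  # trim back to the original length

noisy = np.sin(np.linspace(0, 2 * np.pi, 500)) + 0.1 * np.random.randn(500)
print(moving_average_circular(noisy).shape)  # same length as the input
```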
In [ ]:
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
plt.plot(np.linspace(0, 2*np.pi, 500), smoothed_mean_fisher)
Out[ ]:
[<matplotlib.lines.Line2D at 0x13bdcaeb0>]
No description has been provided for this image
In [ ]:
 
In [ ]:
unif = np.ones(500)

p1 = unif / smoothed_mean_fisher**0.5
p1 = p1 / p1.sum()
In [ ]:
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
 
plt.plot(np.linspace(0, 2*np.pi, 500), p1, 'k')
Out[ ]:
[<matplotlib.lines.Line2D at 0x13becb1c0>]
No description has been provided for this image

Experiment 4 concentrated was trained with the data concentrated around pi/2.

Note that for these plots, I am doubling the angular scale for ease of visualization of the periodic signal. Thus, angles close to zero are, in fact, orthogonal to the pi/2 angles.
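Concretely, the curves are computed on angles in [0, pi) but plotted against doubled angles in [0, 2*pi), so the half-period fills the whole polar circle. A minimal illustration of the convention (variable names here are made up):

```python
import numpy as np

theta = np.linspace(0, np.pi, 500)   # the actual stimulus domain
plot_angle = 2 * theta               # doubled scale used on the polar axes

# the pi/2 concentration therefore appears at plot angle pi,
# and angles near plot angle 0 are orthogonal to it
print(plot_angle[0], plot_angle[-1])
```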

What is the idea of the iteration?

  1. This network has its Fisher information concentrated at pi/2.
  2. Fine-tuning on the reverse, concentrated around 0, should thus remove this sensitivity bias.

Fine-tuning what? Should we fine-tune this network (the already-trained one), or the previous network that we trained in order to get this response to the current stimulus distribution?

Note for the record:¶

These are separately trained versions of the network that we are averaging over to determine the Fisher information.

In [ ]:
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
 
for fi in FIcurves_conc:
    c = unif / fi**0.5
    plt.plot(np.linspace(0, 2*np.pi, 500), c / c.sum())
No description has been provided for this image

The individual runs can produce very noisy results.

Ideas as to why the convergence fails:¶

  1. There are modes that are amplified, rather than damped, in the fitting process.
    • Essentially, this is the case if the distribution above is more extreme than the distribution that the network was trained on.
In [ ]:
from scipy.stats import vonmises
from scipy.integrate import trapezoid

fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})

x = np.linspace(0, 2*np.pi, 500)


p1 = unif / smoothed_mean_fisher**0.5
p1 = p1 / trapezoid(p1, x)

plt.plot(x, vonmises(8., 0).pdf( x))
 
plt.plot(x, p1, 'k')
Out[ ]:
[<matplotlib.lines.Line2D at 0x13bfa2e80>]
No description has been provided for this image

I have to be careful here about how these are normalized, and over which x domain. However, the fact that the values near zero are 'rounded out' is hopeful.
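The two normalization conventions used above differ only by the grid spacing: dividing by `p.sum()` makes the values sum to 1 (per-bin masses), while dividing by `trapezoid(p, x)` makes the curve integrate to 1 (a proper density). A quick check of the distinction:

```python
import numpy as np
from scipy.integrate import trapezoid

x = np.linspace(0, 2 * np.pi, 500)
p = np.exp(np.cos(x))          # any positive curve will do

p_mass = p / p.sum()           # sums to 1: per-bin probability mass
p_dens = p / trapezoid(p, x)   # integrates to 1: probability density

print(p_mass.sum())            # 1 up to floating point
print(trapezoid(p_dens, x))    # 1 up to floating point
print(p_dens[0] / p_mass[0])   # ~ 1/dx: the two differ by the bin width
```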

In [ ]:
from adaptableModel import AngleDistribution
In [ ]:
# this is the former default
a = AngleDistribution(p1, [-np.pi, np.pi])
In [ ]:
update_samples = a.sample(10000)
prior_samples =  vonmises(8., 0).rvs( 10000)
In [ ]:
n_update, b_update = np.histogram(update_samples, bins=50, density=True)
plt.plot(b_update[1:], n_update)

n_prior, b_prior = np.histogram(prior_samples, bins=50, density=True)
plt.plot(b_prior[1:], n_prior, 'k')
Out[ ]:
[<matplotlib.lines.Line2D at 0x13c0180a0>]
No description has been provided for this image

Ok, this is reasonably good: it shows that the distribution is, in fact, getting more uniform, at least for this one step.

In [ ]:
test = AngleDistribution( vonmises(8., 0).pdf( x), [-np.pi, np.pi])
In [ ]:
test_num, test_bins = np.histogram(test.sample(10000), bins=50, density=True)

plt.plot(test_bins[1:], test_num)
plt.plot(b_prior[1:], n_prior, 'k')
Out[ ]:
[<matplotlib.lines.Line2D at 0x13c09f340>]
No description has been provided for this image

These are possible issues:¶

  • AngleDistribution appears to rotate the angles by 180 degrees.
  • AngleDistribution also uses angles between -pi and pi rather than 0 to pi (or some other half-length parameterization), which is inconsistent with the generation.
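One way to pin down the convention, independent of AngleDistribution's internals, is a minimal inverse-CDF sampler over a tabulated pdf (a hypothetical sketch of roughly what `AngleDistribution.sample` presumably does; `sample_from_pdf` is made up here). If its samples peak where the input pdf peaks, any residual rotation must come from the class itself:

```python
import numpy as np

def sample_from_pdf(pdf, domain, n, rng=None):
    """Inverse-CDF sampling from a pdf tabulated on [domain[0], domain[1]]."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.linspace(domain[0], domain[1], len(pdf))
    cdf = np.cumsum(pdf)
    cdf = cdf / cdf[-1]                      # normalize to a proper CDF
    return np.interp(rng.random(n), cdf, x)  # invert the CDF at uniform draws

# a vonmises-like bump at pi/2 on [0, pi): samples should peak there too
x = np.linspace(0, np.pi, 500)
pdf = np.exp(8. * np.cos(2 * (x - np.pi / 2)))
samples = sample_from_pdf(pdf, [0, np.pi], 10000, rng=np.random.default_rng(0))
print(samples.mean())   # close to pi/2 if no rotation is introduced
```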
In [ ]:
test_num, test_bins = np.histogram(test.sample(10000), bins=50, density=True)
rotated_num, rotated_bins = np.histogram( vonmises(8.,np.pi).rvs(10000), bins=50, density=True)

plt.plot(test_bins[1:], test_num)
plt.plot(rotated_bins[1:], rotated_num, '--k')
Out[ ]:
[<matplotlib.lines.Line2D at 0x13c11f550>]
No description has been provided for this image

The density itself seems reasonably unchanged.

In [ ]:
 

What do the iterates look like?¶

In [70]:
from adaptableModel import AngleDistribution

data =pd.read_pickle('trainedParameters/Exp6/untrained/iterate_data.pickle')

colors = plt.cm.viridis(np.linspace(0,1,8))
for row in data[ data.measurement=='probability'].to_dict(orient="records"):
    itr = row['iteration']
    if itr > 0:
        itr = itr-1
    
    dist = AngleDistribution(row['data'], [0, np.pi])
    plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])

plt.title('Untrained Network Trajectory')
Out[70]:
Text(0.5, 1.0, 'Untrained Network Trajectory')
No description has been provided for this image
In [72]:
data =pd.read_pickle('trainedParameters/Exp6/uniform/iterate_data.pickle')

colors = plt.cm.viridis(np.linspace(0,1,8))
for row in data[ data.measurement=='probability'].to_dict(orient="records"):
    itr = row['iteration']
    if itr > 0:
        itr = itr-1
    
    dist = AngleDistribution(row['data'], [0, np.pi])
    plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])

plt.title('Uniform Network Trajectory')
Out[72]:
Text(0.5, 1.0, 'Uniform Network Trajectory')
No description has been provided for this image
In [98]:
data =pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')

colors = plt.cm.viridis(np.linspace(0,1,8))
for row in data[ data.measurement=='probability'].to_dict(orient="records"):
    itr = row['iteration']
    if itr > 0:
        itr = itr-1
    
    dist = AngleDistribution(row['data'], [0, np.pi])
    plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])

plt.title('Concentrated Network Trajectory')
Out[98]:
Text(0.5, 1.0, 'Concentrated Network Trajectory')
No description has been provided for this image
In [92]:
xs =np.linspace(-1, 1, 201)
for i in range(8):
    plt.plot(xs, i* xs, c=colors[i])
No description has been provided for this image

The trend here is clear.¶

It seems that the small deviations in the probability distributions are being successively amplified through the process of finetuning and fitting the Fisher information.

This is a little surprising:
The concentrated network appears to have its biases pretty well removed by the initial fine-tuning, but the newly introduced small biases aren't trained away? How does that make sense?

In [ ]:
 
In [147]:
concentrated_data =pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')
In [148]:
raw = concentrated_data[(concentrated_data.measurement == 'FI') & 
                          (concentrated_data.iteration == 0)]['data']

xs = np.linspace(0, np.pi,500)
for row in raw:
    plt.plot(xs, row)
    
smoothed_mean_fisher = moving_average(np.mean(raw, axis=0))
plt.plot(xs, smoothed_mean_fisher, 'k')
Out[148]:
[<matplotlib.lines.Line2D at 0x141221a30>]
No description has been provided for this image
In [149]:
recorded = concentrated_data[(concentrated_data.iteration == 2) &
                             (concentrated_data.measurement == 'probability')]['data'].iloc[0]

plt.plot(xs, 1./smoothed_mean_fisher**0.5)
plt.plot(xs, recorded, '--k')
Out[149]:
[<matplotlib.lines.Line2D at 0x140ed6550>]
No description has been provided for this image

Ok, good double check: the probability distribution is calculated as expected.

interim summary:¶

This is a seemingly paradoxical result:

  1. In this case, $p(s_1) - p(s_2) > 0$ apparently implies $p'(s_1) - p'(s_2) > p(s_1) - p(s_2)$, i.e. (up to normalization) $p(s_1) / \sqrt{I(s_1)} - p(s_2) / \sqrt{I(s_2)} > p(s_1) - p(s_2)$: the update amplifies differences rather than damping them.
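Spelling out the update rule used in the cells above, $p'(s) \propto p(s)/\sqrt{I(s)}$, with its normalization:

$$p'(s) = \frac{1}{Z}\,\frac{p(s)}{\sqrt{I(s)}}, \qquad Z = \int \frac{p(s)}{\sqrt{I(s)}}\, ds.$$

If $I$ is negatively correlated with $p$, then wherever $p(s_1) > p(s_2)$ we tend to have $I(s_1) < I(s_2)$, so $p(s_1)$ is divided by the smaller factor and the gap $p'(s_1) - p'(s_2)$ widens rather than shrinks: the update amplifies, instead of damping, the existing deviations.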

Is there actually a negative correlation between the probability density that the networks were trained on and the Fisher information in the resulting networks?¶

What if we plot the pdf that we fine-tuned on vs the Fisher information measurements?

In [156]:
concentrated_data = pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')
uniform_data = pd.read_pickle('trainedParameters/Exp6/uniform/iterate_data.pickle')
untrained_data = pd.read_pickle('trainedParameters/Exp6/untrained/iterate_data.pickle')

# fix the indexing issues in the experiment script
# (note: this assumes all three DataFrames share the same row layout,
#  so the index computed from concentrated_data also applies to the others)
inds = concentrated_data[concentrated_data['measurement'] == 'probability'].index
concentrated_data.loc[inds, 'iteration'] = range(8)
uniform_data.loc[inds, 'iteration'] = range(8)
untrained_data.loc[inds, 'iteration'] = range(8)
In [ ]:
 
In [231]:
pt = concentrated_data.pivot_table(index='iteration', columns='measurement',
                                   values='data', aggfunc='mean')

for row in pt.to_dict(orient="records")[::-1]:
    dist = AngleDistribution(row['probability'], [0, np.pi])
    plt.plot(dist.bin_probs, row['FI'][1:], '.')
plt.xlabel('probability')
plt.ylabel('Fisher Info')
    
Out[231]:
Text(0, 0.5, 'Fisher Info')
No description has been provided for this image
In [234]:
pt = uniform_data.pivot_table(index='iteration', columns='measurement',
                                   values='data', aggfunc='mean')

for row in pt.to_dict(orient="records")[::-1]:
    dist = AngleDistribution(row['probability'], [0, np.pi])
    plt.plot(dist.bin_probs, row['FI'][1:], '.')
plt.xlabel('probability')
plt.ylabel('Fisher Info')
    
Out[234]:
Text(0, 0.5, 'Fisher Info')
No description has been provided for this image
In [235]:
pt = untrained_data.pivot_table(index='iteration', columns='measurement',
                                   values='data', aggfunc='mean')

for row in pt.to_dict(orient="records")[::-1]:
    dist = AngleDistribution(row['probability'], [0, np.pi])
    plt.plot(dist.bin_probs, row['FI'][1:], '.')
plt.xlabel('probability')
plt.ylabel('Fisher Info')
Out[235]:
Text(0, 0.5, 'Fisher Info')
No description has been provided for this image

Note that I'm plotting these in reverse order (so grey is the first iteration, then pink, etc).

As the iterations continue (and the probabilities diffuse outward), we see the gradual emergence of an inverse correlation between the probability in the training set and the Fisher information of the learned mapping.

This is the exact opposite of the case previously!

Did the previous case actually show what I thought?¶

In [3]:
import glob

fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
plt.title('Fisher Information: trained networks')

ex4_conc_dir = 'trainedParameters/Exp4_conc/'

FIcurves_conc = []

for rep in range(6):
    trained_ckpt =  glob.glob(ex4_conc_dir + f'rep{rep}/epoch*')[0]
    
    model = EstimateAngle.load_from_checkpoint(trained_ckpt)

    fi = Fisher_smooth_fits(model, 0., np.pi, N_cov=500, Samp_cov=500)
    FIcurves_conc.append(fi)

    print(rep)
    plt.plot(np.linspace(0, 2*np.pi, 500), fi)
0
1
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[3], line 15
     11 trained_ckpt =  glob.glob(ex4_conc_dir + f'rep{rep}/epoch*')[0]
     13 model = EstimateAngle.load_from_checkpoint(trained_ckpt)
---> 15 fi = Fisher_smooth_fits(model, 0., np.pi, N_cov=500, Samp_cov=500)
     16 FIcurves_conc.append(fi)
     18 print(rep)

File ~/Documents/NNtraining/angleFineTuning/discriminationAnalysis.py:35, in Fisher_smooth_fits(model, theta_start, theta_end, N_mean, N_cov, Samp_cov)
     33 FI = []
     34 for i, angle in enumerate(cov_angles):
---> 35     noisy_results = generate_samples(model,  Samp_cov*[angle])
     36     invcov = np.linalg.inv(np.cov(noisy_results.T))
     38     FI.append(derivs[:, deriv_cov_ratio*i] @ invcov @ derivs[:, deriv_cov_ratio*i])

File ~/Documents/NNtraining/angleFineTuning/discriminationAnalysis.py:98, in generate_samples(model, thetas, pixelDim, shotNoise, noiseVar)
     96 def generate_samples(model, thetas, pixelDim=101, shotNoise=0.8, noiseVar=20):
     97     """ generate samples from the model """
---> 98     samples = model.forward(generateGrating(thetas, pixelDim=pixelDim,
     99                                             shotNoise=shotNoise, noiseVar=noiseVar)
    100                             ).detach().numpy()
    101     return samples

File ~/Documents/NNtraining/angleFineTuning/datageneration/stimulusGeneration.py:35, in generateGrating(thetas, frequency, pixelDim, shotNoise, noiseVar)
     32 # add noise to the generated gratings
     33 noiseLocations = binomial(1, shotNoise,
     34                           size=(len(thetas), pixelDim, pixelDim))
---> 35 noiseMagnitude = normal(scale=noiseVar**0.5,
     36                         size=(len(thetas), pixelDim, pixelDim))
     38 Z = torch.clamp(Z + torch.tensor(
     39                      noiseLocations * noiseMagnitude, dtype=torch.float32
     40                     ), min=-1., max=1.)
     42 r2 = X**2 + Y**2

KeyboardInterrupt: 
No description has been provided for this image
In [6]:
from scipy.stats import vonmises
In [20]:
dist_exp4 = vonmises(8., np.pi/2)
samples= dist_exp4.rvs(10000)
In [35]:
plt.hist(samples, bins=20)
plt.hist(samples %np.pi, bins=20)
Out[35]:
(array([   6.,    3.,   33.,   71.,  184.,  336.,  685., 1171., 1462.,
        1608., 1539., 1239.,  805.,  480.,  224.,   87.,   35.,   20.,
           9.,    3.]),
 array([0.13851379, 0.28685553, 0.43519727, 0.58353902, 0.73188076,
        0.8802225 , 1.02856424, 1.17690599, 1.32524773, 1.47358947,
        1.62193121, 1.77027295, 1.9186147 , 2.06695644, 2.21529818,
        2.36363992, 2.51198167, 2.66032341, 2.80866515, 2.95700689,
        3.10534863]),
 <BarContainer object of 20 artists>)
No description has been provided for this image
In [33]:
plt.plot(samples, samples % np.pi, '.')
Out[33]:
[<matplotlib.lines.Line2D at 0x1406528b0>]
No description has been provided for this image

Ok, this is nice: our distributions are very concentrated, so the way that we projected the data down to (0, pi) doesn't really matter.

In [44]:
samples= dist_exp4.rvs(50000) % np.pi
count, bins = np.histogram(samples, bins=np.linspace(0, np.pi, 501), density=True)
In [48]:
for y in FIcurves_conc:
    plt.plot(count, y, '.')

plt.xlabel('Probability density')
plt.ylabel('Fisher Information');
No description has been provided for this image
In [54]:
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})

for y in FIcurves_conc:
    plt.plot(np.linspace(0, 2*np.pi, 500), y)

plt.plot(np.linspace(0, 2*np.pi, 500), 10000*count,'k')
Out[54]:
[<matplotlib.lines.Line2D at 0x1423fa370>]
No description has been provided for this image

Yes, indeed. Here we see a very different dependence between the probability density of the training distribution and the Fisher information in the learned network.

High density -> high Fisher information, and vice versa.

It's good to have triple-checked this.

In [ ]:
 

Notes:¶

Thinking more about it, it is very strange to get such clean behavior. This is an average of multiple retrained samples, but it looks like there is basically no noise?

In [78]:
fig, ax = plt.subplots(2, 3, subplot_kw={'projection': 'polar'})

for ind in range(6):
    for fi in concentrated_data[(concentrated_data.measurement == 'FI')
                              & (concentrated_data.iteration == ind)]['data']:
        plt.subplot(2,3,ind+1)
        plt.plot(np.linspace(0, 2*np.pi, 500), fi)
No description has been provided for this image

No, but seriously, how are these so damn similar???

In [84]:
ckpts = glob.glob('trainedParameters/Exp6/concentrated/iter0/*')
In [87]:
from adaptableModel import AdaptableEstimator, AngleDistribution
In [97]:
U = AngleDistribution(np.ones(500), [0., np.pi])
for ckpt in ckpts:
    model = AdaptableEstimator.load_from_checkpoint(ckpt, angle_dist=U)
    print( model.hparams.seed )
967369843898950914
15346541979810798859
12711906263299886879
14665669494729647154
16652516071571257433
In [98]:
fig, ax = plt.subplots(2, 3, subplot_kw={'projection': 'polar'})

for ind in range(6):
    for fi in uniform_data[(uniform_data.measurement == 'FI')
                              & (uniform_data.iteration == ind)]['data']:
        plt.subplot(2,3,ind+1)
        plt.plot(np.linspace(0, 2*np.pi, 500), fi)
No description has been provided for this image
In [99]:
fig, ax = plt.subplots(2, 3, subplot_kw={'projection': 'polar'})

for ind in range(6):
    for fi in untrained_data[(untrained_data.measurement == 'FI')
                              & (untrained_data.iteration == ind)]['data']:
        plt.subplot(2,3,ind+1)
        plt.plot(np.linspace(0, 2*np.pi, 500), fi)
No description has been provided for this image
In [ ]:
 

Is it possibly overfitting?¶

The models that I'm using to assess the Fisher information aren't necessarily the best checkpoint, since I don't reload these from the file...

Indeed, running a quick check shows that final model parameters are not the optimal ones.

I can test this by comparing the recorded Fisher information to the fits to Fisher information.

In [119]:
concentrated_rerun = []

flat_dist = AngleDistribution(np.ones(500), [0., np.pi])

for iter in range(8):
    files = glob.glob(f'trainedParameters/Exp6/concentrated/iter{iter}/*')

    for checkpoint in files:
        model = AdaptableEstimator.load_from_checkpoint(checkpoint,
                                                        angle_dist=flat_dist,
                                                        max_epochs=0
                                                        )
        fi = Fisher_smooth_fits(model, 0., np.pi, N_mean=10000, N_cov=500, Samp_cov=500)
        
        row = {'iteration': iter, 'data': fi}
        concentrated_rerun.append(row)
    print(iter)
0
1
2
3
4
5
6
7
In [125]:
concentrated_rerun = pd.DataFrame(concentrated_rerun)
In [127]:
concentrated_rerun.groupby('iteration').mean()
Out[127]:
data
iteration
0 [7621.458407643084, 7004.0889070300955, 6895.1...
1 [8126.678988741916, 8286.100369626003, 7994.64...
2 [7979.954776736768, 7580.665071036531, 7865.98...
3 [8077.784758112767, 7796.0586707988205, 8172.8...
4 [8362.651840080049, 8816.162014109344, 8063.81...
5 [8089.661711997667, 7962.0554153758085, 7504.7...
6 [8612.287292079103, 8410.9201702268, 8117.0090...
7 [8226.916994561674, 8726.015627317034, 7589.94...
In [153]:
colors = plt.cm.viridis(np.linspace(0,1,8))
i =0

for row in concentrated_data[ concentrated_data.measurement == 'FI'
                            ].groupby('iteration').agg({'data': 'mean'}).sort_values('iteration')['data']:
    plt.plot(moving_average(row), c=colors[i])
    i += 1

plt.title('Fisher information - online record')
Out[153]:
Text(0.5, 1.0, 'Fisher information - online record')
No description has been provided for this image
In [154]:
colors = plt.cm.viridis(np.linspace(0,1,8))
i =0

for row in concentrated_rerun.groupby('iteration').mean().sort_values('iteration')['data']:
    plt.plot(moving_average(row), c=colors[i])
    i += 1

plt.title('Fisher Information - posthoc evaluation')
Out[154]:
Text(0.5, 1.0, 'Fisher Information - posthoc evaluation')
No description has been provided for this image
  1. There is some difference between the two sets of curves.

  2. The differences across iterations don't look strong enough to produce the trend that we've observed in the previous plots.

In [160]:
colors = plt.cm.viridis(np.linspace(0,1,9))

dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])

i = 1
for row in concentrated_data[ concentrated_data.measurement == 'FI'
                            ].groupby('iteration').agg({'data': 'mean'}).sort_values('iteration')['data']:

    new_values = dist.values / moving_average(row)**0.5
    dist = AngleDistribution(new_values, [0., np.pi])
    plt.plot(dist.bin_probs, c=colors[i])
    i += 1

plt.title('Probability distribution')
Out[160]:
Text(0.5, 1.0, 'Probability distribution')
No description has been provided for this image
In [161]:
colors = plt.cm.viridis(np.linspace(0,1,9))

dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])

i = 1
for row in concentrated_rerun.groupby('iteration').mean().sort_values('iteration')['data']:
    new_values = dist.values / moving_average(row)**0.5
    dist = AngleDistribution(new_values, [0., np.pi])
    plt.plot(dist.bin_probs, c=colors[i])
    i += 1
No description has been provided for this image

Interesting result¶

Ok, so it looks like the thing that causes this divergence is actually the fact that the Fisher information deviates from uniform in the same characteristic way every time we refit the neural networks.

That is to say it's because the Fisher information curves all have the same large scale shape.

This means that we divide by the same thing every time, which in turn causes the amplification of deviations.

That also explains the very deterministic seeming nature of the deviations: it really is repeated division by the same curve.
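The amplification mechanism can be reproduced in isolation: iterate p ← p/√I with a fixed, mildly non-uniform I (a made-up curve here, not the measured Fisher information) and watch the max/min ratio grow geometrically:

```python
import numpy as np

x = np.linspace(0, np.pi, 500)
I = 1.0 + 0.1 * np.cos(2 * x)   # fixed curve with a 10% deviation from uniform

p = np.ones_like(x)
p = p / p.sum()
ratios = []
for _ in range(8):              # same number of iterations as Exp6
    p = p / np.sqrt(I)          # the pdf update
    p = p / p.sum()             # renormalize
    ratios.append(p.max() / p.min())

# after k iterations p is proportional to I**(-k/2), so the max/min ratio
# grows like (I.max()/I.min())**(k/2): geometric amplification
print(ratios)
```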

The question is: why is it the same?¶

Why do the networks all share this structure?

  • Is it the shared initialization?
  • Is it a randomization failure?
  • Is it an architecture failure?
  • Is it a Fisher information fitting failure?

Seed sharing¶

In [164]:
for ckpt in glob.glob('trainedParameters/Exp6/concentrated/iter0/*'):
    dist = AngleDistribution(np.ones(500), [0, np.pi])
    model = AdaptableEstimator.load_from_checkpoint( ckpt, angle_dist=dist)
    print(model.hparams.seed)
967369843898950914
15346541979810798859
12711906263299886879
14665669494729647154
16652516071571257433
In [165]:
for ckpt in glob.glob('trainedParameters/Exp6/concentrated/iter1/*'):
    dist = AngleDistribution(np.ones(500), [0, np.pi])
    model = AdaptableEstimator.load_from_checkpoint( ckpt, angle_dist=dist)
    print(model.hparams.seed)
10666304192740685704
12834590379526135592
10890395108299221895
11197852224913400512
13072276661895679149

The seeds that we recorded are not the same. Is it possible that the seed is not being set?

In [169]:
import torch
model1 = AdaptableEstimator.load_from_checkpoint(ckpt, angle_dist=dist, seed=torch.random.seed())
model2 = AdaptableEstimator.load_from_checkpoint(ckpt, angle_dist=dist, seed=torch.random.seed())
In [171]:
print(model1.hparams.seed, model2.hparams.seed)
13813948499615692093 7102135249791225486
In [172]:
model1.setup()
In [173]:
model2.setup()
In [181]:
model1.trainingData.angles
Out[181]:
tensor([1.3316, 0.9727, 2.6663,  ..., 0.2046, 1.8856, 2.8803])
In [178]:
model2.trainingData.angles
Out[178]:
tensor([1.5960, 1.6338, 2.3326,  ..., 0.0220, 2.0367, 2.5341])
In [183]:
plt.imshow(model1.trainingData.images[0] - model2.trainingData.images[0])
Out[183]:
<matplotlib.image.AxesImage at 0x171c5e820>
No description has been provided for this image

Certainly the data generated within the two models looks different.

Also, the initialization code seems to run upon loading the models.

I'm pretty confident that the seeds are different between the models.¶

Single iterates¶

In [193]:
for fi in concentrated_data[(concentrated_data['measurement'] == 'FI')]['data']:
    plt.plot(fi)
No description has been provided for this image

It definitely looks more correlated than I would expect.

In [210]:
colors = plt.cm.viridis(np.linspace(0,1,6))

dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])

i = 1
for row in concentrated_data[(concentrated_data.measurement == 'FI') &
                             (concentrated_data.iteration == 0)
                            ]['data']:

    new_values = dist.values / moving_average(row)**0.5
    dist = AngleDistribution(new_values, [0., np.pi])
    plt.plot(dist.bin_probs, c=colors[i])
    i += 1

plt.title('Do all replicates in each iteration have the same structure?')
Out[210]:
Text(0.5, 1.0, 'Do all replicates in each iteration have the same structure?')
No description has been provided for this image

Yep, this is the same behavior as across iterates: there is a surprising amount of shared structure between the replicates.

In [212]:
colors = plt.cm.viridis(np.linspace(0,1,6))

dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])

i = 1
for row in concentrated_data[(concentrated_data.measurement == 'FI') &
                             (concentrated_data.iteration == 0)
                            ]['data']:

    new_values = dist.values / row**0.5
    dist = AngleDistribution(new_values, [0., np.pi])
    plt.plot(dist.bin_probs, c=colors[i])
    i += 1

plt.title('What about the un-smoothed versions?')
Out[212]:
Text(0.5, 1.0, 'What about the un-smoothed versions?')
No description has been provided for this image

I mean, yes. This is essentially the same as the behavior that we saw above.

So the smoothing is not the cause of the similarity.

Is it a result of the shared initialization?¶

In this way, initialization looks to be a sort of constraint on the network that is being pulled out of the noise?

Experiment: run different initializations, and see if the Fisher information of the learned networks is dependent on the initialization.

In [3]:
data= pd.read_pickle('experiment_result/ex6_initialization.pickle')
data
Out[3]:
rep method Fisher
0 0 loaded [7482.413039576917, 6917.13476382538, 6849.097...
1 0 loaded [10166.779750083673, 8939.67076884833, 9921.07...
2 0 loaded [8342.131900495337, 7754.1581290318945, 8720.9...
3 0 set state [10246.708972247876, 10534.179180788991, 10474...
4 0 set state [8194.859422943355, 9120.434044892028, 10009.6...
5 0 set state [9891.966864509444, 10917.245571615525, 9565.3...
6 1 loaded [9103.473100671537, 8093.2921956849395, 7738.5...
7 1 loaded [8049.016757678765, 9203.03980910144, 10708.35...
8 1 loaded [8168.780814261347, 7271.36786233274, 7416.672...
9 1 set state [7484.556178987075, 7061.951097034991, 7524.17...
10 1 set state [8361.456949634292, 9376.600646718349, 8269.81...
11 1 set state [7888.846354406538, 8725.683978198615, 8975.97...
12 2 loaded [7597.5959725953635, 8650.741227010567, 7785.3...
13 2 loaded [8212.265420675509, 8202.753571510018, 9572.20...
14 2 loaded [8401.989415222208, 8165.419969414646, 9289.19...
15 2 set state [7875.954741235541, 8403.249508216057, 7851.75...
16 2 set state [8102.391485970146, 8428.492809881152, 7607.05...
17 2 set state [7750.2276935275795, 8183.110812613867, 9789.5...
18 3 loaded [7888.236253710033, 7438.313531107031, 7802.09...
19 3 loaded [9388.040978595753, 11182.55700286475, 10760.2...
20 3 loaded [9358.530614150337, 8728.840107521382, 8636.65...
21 3 set state [7974.953282728307, 8173.304388463969, 8817.06...
22 3 set state [7547.519832147966, 7306.104821011238, 5886.70...
23 3 set state [8097.5375863064755, 7504.368573390664, 7586.6...
24 4 loaded [9767.59402579983, 8139.895565234726, 9224.221...
25 4 loaded [8548.035607218688, 7886.105720024892, 7104.46...
26 4 loaded [8347.773943916216, 9773.953499879903, 8465.08...
27 4 set state [7509.38839085409, 7101.0802125403625, 8569.01...
28 4 set state [9480.810211138492, 9434.11025135334, 11616.58...
29 4 set state [9012.03739098086, 8643.40802404131, 8002.1764...
30 5 loaded [8199.180526830847, 7921.528926425352, 7307.44...
31 5 loaded [7655.612771660323, 7651.739397811222, 8399.38...
32 5 loaded [7890.474416216101, 7784.39781415223, 7951.540...
33 5 set state [9467.164749292322, 8209.064123106846, 9037.32...
34 5 set state [8637.367718021858, 8000.276341570653, 8157.63...
35 5 set state [7879.291543139089, 6833.589572237475, 8048.05...
In [23]:
plt.subplots(3,2)

for row in data[data.method == 'loaded'].to_dict(orient="records"):
    plt.subplot(3,2,row['rep']+1)
    plt.plot(moving_average(row['Fisher']))
No description has been provided for this image
In [24]:
plt.subplots(3,2)

for row in data[data.method == 'set state'].to_dict(orient="records"):
    plt.subplot(3,2,row['rep']+1)
    plt.plot(moving_average(row['Fisher']))
No description has been provided for this image

Ok, it's hard to tell by eye whether the series are more similar within bins than between bins.

In [146]:
colors = plt.cm.viridis(np.linspace(0,1,7))

def group_std(x): return x.to_numpy().std()

xs = np.linspace(0, np.pi, 500)

i=1
for row in data[data.method == 'loaded'].groupby('rep'
                                       ).agg({'Fisher':['mean', group_std]}
                                       ).to_dict(orient="records"):
    mean = row[('Fisher', 'mean')]
    err = row[('Fisher','group_std')]
    plt.plot(xs, mean, c=colors[i])
    plt.plot(xs, mean-err, '--', c=colors[i])
    plt.plot(xs, mean+err, '--', c=colors[i])

    i+=1
plt.title('Replicate means and variances')
Out[146]:
Text(0.5, 1.0, 'Replicate means and variances')
No description has been provided for this image

That visualization is worthless

In [148]:
sns.heatmap(np.corrcoef( np.array(data['Fisher'].to_list()) ))
plt.title('Pearson correlation heat map');
No description has been provided for this image
In [153]:
sns.heatmap(np.corrcoef( np.array(data['Fisher'].to_list()) ), vmin=-0.3, vmax=0.5)
plt.title('Pearson correlation heat map - zoom');
No description has been provided for this image

There doesn't really seem to be much structure in this measurement. We are looking for structure in the form of square blocks of size 3 or 6 along the diagonal.

Maybe I can convince myself that there is some such structure (e.g. 12-26, 7-10, 24-26), but it doesn't align with the changes in initialization.

I conclude that either this is a poor measure, or there is no such initialization structure.
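Rather than judging block structure by eye, the within-block vs. between-block correlation can be contrasted directly. A minimal sketch with synthetic curves (the block size of 3 mirrors the three runs per condition; `block_contrast` is a made-up helper, not part of the analysis code):

```python
import numpy as np

def block_contrast(corr, block):
    """Mean correlation inside contiguous diagonal blocks of size
    `block` (excluding the diagonal itself) minus the mean correlation
    between blocks. Positive values indicate block structure."""
    n = corr.shape[0]
    groups = np.arange(n) // block
    same = groups[:, None] == groups[None, :]
    off_diag = ~np.eye(n, dtype=bool)
    return corr[same & off_diag].mean() - corr[~same].mean()

rng = np.random.default_rng(0)
# 12 curves in 4 groups of 3, each group sharing a common signal
signal = np.repeat(rng.normal(size=(4, 500)), 3, axis=0)
structured = np.corrcoef(signal + 0.5 * rng.normal(size=(12, 500)))
unstructured = np.corrcoef(rng.normal(size=(12, 500)))

print(block_contrast(structured, 3))    # large and positive
print(block_contrast(unstructured, 3))  # near zero
```

Applied to the heat map above, a contrast near zero at block sizes 3 and 6 would back up the "no structure" reading quantitatively.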

Positive control¶

In [157]:
concentrated_data = pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')
uniform_data = pd.read_pickle('trainedParameters/Exp6/uniform/iterate_data.pickle')
untrained_data = pd.read_pickle('trainedParameters/Exp6/untrained/iterate_data.pickle')

# fix the indexing issue in the experiment script: renumber the
# 'probability' rows of each frame using that frame's own index
for df in (concentrated_data, uniform_data, untrained_data):
    inds = df[df['measurement'] == 'probability'].index
    df.loc[inds, 'iteration'] = range(8)
In [168]:
plt.subplots(3,1)

for fi in concentrated_data[(concentrated_data['measurement'] == 'FI')]['data']:
    plt.subplot(3,1,1)
    plt.plot(fi)

    plt.subplot(3,1,3)
    plt.plot(fi)

for fi in uniform_data[(uniform_data['measurement'] == 'FI')]['data']:
    plt.subplot(3,1,2)
    plt.plot(fi)

    plt.subplot(3,1,3)
    plt.plot(fi)
No description has been provided for this image

By eye, the data seems to group into distinct curves. Is this captured by the mapping?

In [193]:
both = pd.concat([
    concentrated_data[(concentrated_data['measurement'] == 'FI')],
    uniform_data[(uniform_data['measurement'] == 'FI')]], axis=0)
In [196]:
sns.heatmap(np.corrcoef( np.array(both['data'].to_list()) ) )
Out[196]:
<Axes: >
No description has been provided for this image

Yeah, ok. This is pretty clear.

(note that these are the unsmoothed trajectories!)

In [198]:
all_traj = pd.concat([
    concentrated_data[(concentrated_data['measurement'] == 'FI')],
    uniform_data[(uniform_data['measurement'] == 'FI')],
    untrained_data[(untrained_data['measurement'] == 'FI')]
    ], axis=0)

sns.heatmap(np.corrcoef( np.array(all_traj['data'].to_list()) ) )
Out[198]:
<Axes: >
No description has been provided for this image

The within-class correlation is greatly diminished when we look at the untrained initialization. In fact, it looks similar to the results from the second experiment.

In [199]:
for fi in untrained_data[(untrained_data['measurement'] == 'FI')]['data']:
    plt.plot(fi)
No description has been provided for this image

Indeed, the untrained networks appear much more spread out than the pretrained ones!

However, this only shows up weakly in the Fisher information iteration.

In [205]:
mean_fi = untrained_data[untrained_data.measurement == 'FI'] \
              .groupby('iteration').agg({'data': 'mean'})['data']
for row in mean_fi:
    plt.plot(row)
No description has been provided for this image
In [214]:
from adaptableModel import AngleDistribution

colors = plt.cm.viridis(np.linspace(0,1,8))
for row in untrained_data[ untrained_data.measurement=='probability'].to_dict(orient="records"):
    itr = row['iteration']
    
    dist = AngleDistribution(row['data'], [0, np.pi])
    plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])

plt.title('Untrained Network Trajectory');
No description has been provided for this image
In [336]:
dist = AngleDistribution(np.ones(500), [0, np.pi])

colors = plt.cm.viridis(np.linspace(0,1,8))
i=0

mean_fi = untrained_data[untrained_data.measurement == 'FI'] \
              .groupby('iteration').agg({'data': 'mean'})['data']
for row in mean_fi:
    smoothed_fi = moving_average(row)
    dist = AngleDistribution( dist.values / smoothed_fi, [0, np.pi])
    
    plt.plot(dist.bin_probs, c=colors[i])
    i+=1
plt.title('Untrained: recomputed bin probabilities');
No description has been provided for this image

What happens if I resample across iterates?

In [339]:
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))

to_sample = untrained_data[untrained_data.measurement == 'FI']

for iter in range(8):
    inds = np.random.choice(range(40), 1)
    sample_mean = np.mean(to_sample.iloc[inds]['data'].to_list(), axis=0)
    
    smoothed_fi = moving_average(sample_mean)
    dist = AngleDistribution( dist.values / smoothed_fi, [0, np.pi])
    
    plt.plot(dist.bin_probs, c=colors[iter])
plt.title('Untrained: random trajectory per iterate');
No description has been provided for this image

The same behavior emerges when we sample a random subset of iterations, and even when we simply use the first set of iterations.

The probability distributions seem to diverge away from uniform.
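This impression can be checked by tracking a scalar distance from uniform per iterate, e.g. the total-variation distance. A minimal sketch under the same update rule, using a uniform-noise stand-in for the measured Fisher curves:

```python
import numpy as np

def tv_from_uniform(p):
    """Total-variation distance between a probability vector and the
    uniform distribution on the same support (0 = uniform, max 1)."""
    p = np.asarray(p) / np.sum(p)
    return 0.5 * np.abs(p - 1.0 / p.size).sum()

rng = np.random.default_rng(0)
p = np.ones(500) / 500
dists = []
for _ in range(8):
    fake_fi = 100 * rng.random(500) + 100  # stand-in for a noisy FI curve
    p = p / fake_fi
    p = p / p.sum()
    dists.append(tv_from_uniform(p))

print(np.round(dists, 3))  # distance from uniform keeps growing
```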

In [291]:
dist.values[0:10]
Out[291]:
array([5.24937531e-32, 5.05582106e-32, 4.89830818e-32, 4.70749811e-32,
       4.60765458e-32, 4.56129761e-32, 4.54834814e-32, 4.52298889e-32,
       4.59708996e-32, 4.60470690e-32])

Oh boy, is it round off error??

In [338]:
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))

to_sample = untrained_data[untrained_data.measurement == 'FI']

for iter in range(8):
    inds = np.random.choice(range(40), 1)
    sample_mean = np.mean(to_sample.iloc[inds]['data'].to_list(), axis=0)
    
    smoothed_fi = moving_average(sample_mean)

    new_probs = dist.values / smoothed_fi
    new_probs = new_probs / new_probs.sum()
    dist = AngleDistribution(new_probs, [0, np.pi])
    
    plt.plot(dist.bin_probs, c=colors[iter])

plt.title('Renormalized to avoid round-off');
No description has been provided for this image

Nope, that's still not it.

I should fix this though!

In [332]:
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))

to_sample = untrained_data[untrained_data.measurement == 'FI']

for iter in range(8):
    sample_mean = 100*np.random.rand(500) +100
    
    smoothed_fi = moving_average(sample_mean)

    new_probs = dist.values / smoothed_fi
    new_probs = new_probs / new_probs.sum()
    dist = AngleDistribution(new_probs, [0, np.pi])
    
    plt.plot(dist.bin_probs, c=colors[iter])
plt.title('Purely Random Fisher information')
Out[332]:
Text(0.5, 1.0, 'Purely Random Fisher information')
No description has been provided for this image

Hmm. Using random data, the iterations still appear to drift away from uniform, although not nearly as clearly.

This seems to be an instability in the iterations themselves.
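The instability has a simple mechanism: under `p <- normalize(p / f)` with independent noisy `f` at each iterate, `log p` performs a random walk across bins, so its variance grows linearly with the iteration count. A minimal sketch of that accumulation, with the same uniform-noise stand-in for the Fisher curves:

```python
import numpy as np

rng = np.random.default_rng(0)
log_p = np.zeros(500)   # log of an (unnormalized) uniform start
var_log = []
for _ in range(8):
    fake_fi = 100 * rng.random(500) + 100
    log_p -= np.log(fake_fi)   # normalization only shifts log_p by a constant
    var_log.append(np.var(log_p))

# variance across bins grows roughly linearly with the iterate index
slope = np.polyfit(np.arange(1, 9), var_log, 1)[0]
print([round(v, 3) for v in var_log])
```

The slope of that line is just the per-step variance of `log f`, which is why the divergence shows up even with purely random "Fisher information".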

In [361]:
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))

to_sample = untrained_data[untrained_data.measurement == 'FI']

plt.subplots(2,1)
vars = []
for iter in range(8):
    sample_mean = 100*np.random.rand(500) +100

    new_probs = dist.values / sample_mean
    new_probs = new_probs / new_probs.sum()
    dist = AngleDistribution(new_probs, [0, np.pi])
    vars.append( dist.bin_probs.var() )

    plt.subplot(2,1,1)
    plt.plot(dist.bin_probs, c=colors[iter])
plt.title('Random - no smoothing')

plt.subplot(2,1,2)
plt.title('Variance')
plt.plot(vars, '.')
Out[361]:
[<matplotlib.lines.Line2D at 0x170ab3fd0>]
No description has been provided for this image
In [374]:
dist1 = AngleDistribution(np.ones(500), [0, np.pi])
dist2 = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))

to_sample = untrained_data[untrained_data.measurement == 'FI']
v1 = []
v2 = []

plt.subplots(3,1)
for iter in range(8):
    sample_mean = 100*np.random.rand(500) +100

    new_probs1 = dist1.values / sample_mean
    new_probs1 = new_probs1 / new_probs1.sum()
    dist1 = AngleDistribution(new_probs1, [0, np.pi])
    v1.append( dist1.bin_probs.var() )
    
    new_probs2 = dist2.values / moving_average(sample_mean)
    new_probs2 = new_probs2 / new_probs2.sum()
    dist2 = AngleDistribution(new_probs2, [0, np.pi]) 
    v2.append( dist2.bin_probs.var() )
    
    plt.subplot(3,1,1)
    plt.title('No online smoothing - smoothed for plotting')
    plt.plot(moving_average(dist1.bin_probs), c=colors[iter])
    plt.subplot(3,1,2)
    plt.title('With online smoothing')
    plt.plot(dist2.bin_probs, c=colors[iter])

plt.subplot(3,1,1)
plt.suptitle('Head to head - smoothing or no');

plt.subplot(3,1,3)
plt.title('variance')
plt.plot(v1, '.', label='Not smoothed')
plt.plot(v2, '.', label='Smoothed')
plt.legend()
Out[374]:
<matplotlib.legend.Legend at 0x171983f70>
No description has been provided for this image

Result¶

Ok, this is very illuminating: it looks like the divergence is caused by a combination of the online smoothing and compounding noise.

Noise¶

This is caused by a lack of mean-reversion: there is nothing pulling the distribution back toward uniform, so successive iterations simply diffuse further away on average. I was hoping that the networks themselves would do the mean reversion. Absent that, the diffusion is inevitable. This can be seen in the linearly increasing variance across iterates.
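If a pull back toward uniform is what's missing, one could impose it in the update itself, e.g. by mixing each new distribution with the uniform one. A minimal sketch of such a mean-reverting update; the mixing weight `lam` and the helper `reverting_update` are made up here, not part of AngleDistribution or the experiment code:

```python
import numpy as np

def reverting_update(p, fisher, lam=0.5):
    """One pdf update, pulled back toward uniform with weight (1 - lam)."""
    q = p / fisher
    q = q / q.sum()
    u = np.ones_like(p) / p.size
    return (1 - lam) * u + lam * q

rng = np.random.default_rng(0)
p = np.ones(500) / 500
var_hist = []
for _ in range(8):
    fake_fi = 100 * rng.random(500) + 100  # stand-in for a noisy FI curve
    p = reverting_update(p, fake_fi)
    var_hist.append(p.var())

# with lam < 1 the per-bin variance plateaus instead of growing linearly
print(var_hist)
```

With `lam < 1` the deviation from uniform settles to a stationary level (roughly `lam**2 / (1 - lam**2)` times the per-step noise variance) instead of growing without bound.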

Smoothing¶

When we don't smooth, the iterations are all independent noisy samples, with linearly increasing variance. Smoothing introduces some serial dependence between the samples. By eye, this doesn't make the deviation worse. It does make the variance increase more slowly (not visibly here, but it is still increasing).

Important to dos:¶

  1. normalize the distribution values when they are updated to avoid round-off error. ✅
  2. investigate the uniform and concentrated networks: their trajectories are highly correlated.
    • step 1: save weights only in the checkpoints ✅
    • step 2: compare retraining with and without saving only the weights: do the results still hold? ✅
  3. Mean reversion scale of networks: what size of perturbations (width or height-wise) does the network revert?
In [ ]:
 

Addressing the similarity of pre-trained networks¶

How did the hyperparameters change with iterations in the first experiment?¶

Maybe if the optimizer parameters shrink down to be really small, this would explain what we see.

In [115]:
model_data0 = torch.load('trainedParameters/Exp6/untrained/iter0/epoch=192-step=24704.ckpt',
                        map_location='cpu')

model_data1 = torch.load('trainedParameters/Exp6/untrained/iter1/epoch=149-step=19200.ckpt',
                        map_location='cpu')
model_data_init = torch.load('trainedParameters/Exp6/untrained.ckpt', map_location='cpu')
In [118]:
model_data1['optimizer_states'][0]
Out[118]:
{'state': {0: {'step': tensor(19200.),
   'exp_avg': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]),
   'exp_avg_sq': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]])},
  1: {'step': tensor(19200.),
   'exp_avg': tensor([ 3.9950e-08,  2.2974e-08,  3.7015e-08,  2.8690e-08, -1.5791e-08,
           -2.2150e-08,  1.8757e-08, -1.8802e-08, -4.8500e-08,  9.0094e-08,
            4.2652e-09,  4.8591e-09, -1.3988e-08,  5.3582e-08,  5.0246e-08,
            1.8181e-08, -5.6266e-08,  1.1453e-07,  1.7005e-07, -8.2164e-09,
           -1.8664e-07, -1.5436e-09, -1.1229e-07, -8.8392e-08, -9.8638e-08,
            2.1685e-07, -1.0704e-08,  2.6447e-08,  1.4891e-08, -3.2266e-08,
           -4.2134e-08,  1.2371e-08,  6.0001e-08, -5.4728e-08, -2.7055e-08,
           -3.1886e-09,  4.7419e-08, -8.4212e-08,  7.0370e-08,  3.9694e-08,
           -2.9081e-08, -1.9754e-08,  5.9768e-08, -4.7940e-08, -8.3292e-08,
            3.1447e-08,  2.9186e-07,  3.7023e-09, -9.7024e-08, -5.2214e-08,
           -3.2593e-08,  4.0332e-09, -2.7107e-08, -7.3101e-08, -1.3235e-08,
           -1.7992e-07,  2.5443e-07,  1.5715e-08, -2.3786e-08,  5.1012e-08,
           -1.5511e-07, -2.4289e-08, -4.6096e-08,  1.9318e-08, -2.1846e-09,
            1.5979e-08,  4.0604e-08,  3.1175e-08,  6.5061e-08, -1.9150e-08,
           -8.3222e-09,  2.1977e-08, -1.8897e-08, -1.5616e-08, -5.8125e-09,
           -4.1069e-08, -6.1176e-08, -1.8599e-08, -9.8040e-08,  2.1829e-08,
            9.3983e-09, -4.0516e-08, -4.6373e-08, -7.3704e-08, -1.8467e-08,
           -3.2379e-08, -6.0409e-08, -1.0234e-07, -6.5500e-08, -2.0492e-07,
           -6.5358e-08, -5.3337e-09,  2.9970e-08,  3.5137e-09,  3.9394e-08,
            8.7156e-08, -9.1030e-09, -1.9250e-08, -4.0935e-08, -9.0614e-08,
            6.6389e-08]),
   'exp_avg_sq': tensor([2.9767e-13, 1.1291e-13, 1.5908e-13, 8.0122e-14, 2.2190e-14, 4.9398e-13,
           6.5349e-14, 3.8010e-14, 1.6380e-13, 3.9109e-13, 6.3192e-14, 1.0245e-13,
           2.2673e-13, 1.9131e-13, 2.0234e-13, 1.4576e-13, 2.5629e-13, 3.0221e-13,
           4.2077e-13, 4.0549e-14, 5.4309e-13, 4.8692e-13, 1.6738e-13, 4.7163e-13,
           1.3034e-13, 4.8022e-13, 1.3200e-14, 5.7385e-13, 2.4600e-13, 1.2623e-13,
           1.9026e-13, 3.7299e-13, 2.8063e-13, 1.9123e-13, 2.8781e-13, 5.7911e-13,
           2.9629e-13, 1.8802e-13, 2.7304e-13, 3.4340e-13, 1.7665e-13, 1.9906e-13,
           2.8867e-13, 8.1971e-14, 7.1792e-13, 2.4038e-13, 3.5841e-13, 4.3775e-13,
           3.7669e-13, 2.9406e-13, 1.7000e-13, 5.8291e-14, 1.1526e-13, 1.9875e-13,
           3.2290e-13, 6.8843e-13, 1.0436e-12, 6.7710e-14, 2.2433e-13, 5.9616e-13,
           4.3518e-13, 4.7085e-13, 5.1738e-13, 1.7276e-13, 1.6865e-13, 1.8499e-13,
           2.2188e-13, 4.4883e-13, 2.5631e-13, 7.5779e-14, 1.7237e-13, 9.1031e-13,
           4.4974e-13, 1.0054e-13, 5.7251e-13, 9.2990e-14, 1.2516e-12, 6.2778e-14,
           3.5808e-13, 3.7556e-14, 5.4941e-13, 1.0891e-13, 6.5609e-14, 2.7047e-13,
           2.1465e-13, 1.4819e-14, 1.0981e-13, 1.1930e-13, 5.1540e-13, 4.6762e-13,
           7.3987e-14, 5.0239e-13, 7.7537e-14, 1.1692e-13, 1.1469e-13, 2.5231e-13,
           2.2377e-14, 1.9400e-13, 7.7623e-14, 2.7656e-13, 4.8752e-13])},
  2: {'step': tensor(19200.),
   'exp_avg': tensor([[ 3.9751e-07,  3.5196e-07,  1.0579e-06,  ..., -1.1782e-06,
             9.9092e-07,  1.2076e-06],
           [-4.0587e-09, -7.7670e-08, -7.0465e-09,  ..., -4.8705e-07,
            -3.4182e-06, -1.1919e-08],
           [ 5.7123e-06,  4.2821e-06,  4.2875e-06,  ..., -2.8615e-06,
             8.7897e-09,  4.1755e-06],
           ...,
           [-8.2580e-07,  1.5665e-06, -1.2411e-06,  ..., -3.0074e-07,
            -8.0529e-07, -2.2139e-07],
           [-9.0013e-06,  1.1886e-06, -5.8878e-08,  ..., -1.5740e-06,
            -1.2207e-06, -8.4139e-06],
           [ 1.9409e-08, -4.9159e-09, -4.8037e-08,  ..., -2.8185e-07,
             1.8700e-07,  2.9006e-08]]),
   'exp_avg_sq': tensor([[2.6704e-12, 2.2422e-12, 8.4092e-12,  ..., 1.5337e-10, 6.1572e-09,
            3.7416e-10],
           [1.1404e-12, 1.3320e-12, 1.5067e-12,  ..., 1.0863e-11, 4.8380e-10,
            1.1343e-12],
           [2.0850e-10, 1.9130e-09, 2.0834e-09,  ..., 1.5120e-09, 6.5665e-12,
            9.9315e-10],
           ...,
           [1.9818e-09, 4.2680e-10, 2.3925e-09,  ..., 9.5324e-11, 7.4941e-11,
            1.4853e-09],
           [1.2304e-09, 1.0797e-09, 1.7017e-09,  ..., 1.1278e-10, 2.8729e-11,
            6.6484e-10],
           [2.7203e-13, 1.6071e-13, 4.1965e-13,  ..., 5.4241e-12, 6.0056e-12,
            2.7558e-13]])},
  3: {'step': tensor(19200.),
   'exp_avg': tensor([-7.0225e-08, -4.9133e-07,  3.5006e-07, -2.1929e-07,  6.0389e-08,
            3.3106e-07,  1.8911e-07,  2.8392e-07,  3.5667e-07, -7.3169e-07,
            2.7290e-07,  4.8043e-07, -1.6849e-07,  2.7576e-08,  6.6279e-07,
           -2.6198e-07,  1.1365e-37, -2.5275e-08, -3.3905e-07, -4.7240e-08]),
   'exp_avg_sq': tensor([9.8487e-12, 4.2503e-12, 1.7769e-11, 5.5501e-12, 5.3832e-13, 2.3923e-11,
           7.1915e-13, 1.4934e-11, 2.8864e-12, 9.7026e-12, 4.1048e-11, 5.9378e-12,
           1.2309e-11, 3.2376e-13, 5.4346e-12, 8.4505e-12, 3.2725e-14, 1.0349e-11,
           4.4490e-12, 1.2267e-12])},
  4: {'step': tensor(19200.),
   'exp_avg': tensor([[ 7.4749e-05,  4.3184e-05,  9.7801e-05, -7.2232e-05, -7.2613e-06,
            -2.7013e-05, -1.2499e-05,  3.3588e-04,  9.9370e-05,  2.2992e-04,
             7.6647e-05,  1.9692e-04, -8.3657e-05,  1.0178e-06, -6.2391e-07,
            -7.0467e-05, -1.1380e-37, -1.4542e-04, -4.0040e-05,  2.8469e-06],
           [-2.1076e-06, -4.5587e-06, -6.6048e-05,  5.0901e-05, -9.4604e-06,
            -7.1190e-05, -2.7243e-05,  1.1919e-04,  7.2453e-05,  2.5918e-05,
            -1.8415e-04,  8.0264e-05,  1.5577e-05, -4.3151e-06,  1.1833e-04,
             3.5031e-05,  1.1247e-37,  1.6499e-04, -5.2236e-05, -5.8078e-05]]),
   'exp_avg_sq': tensor([[5.3103e-07, 1.3128e-07, 1.6914e-06, 2.5481e-07, 6.5761e-08, 1.6418e-06,
            4.1097e-08, 2.3069e-06, 3.2923e-07, 1.0400e-06, 1.1325e-06, 6.9244e-07,
            2.4119e-07, 4.3331e-08, 7.8954e-07, 5.3556e-07, 1.9832e-15, 5.1401e-07,
            1.1419e-07, 1.1320e-08],
           [1.7107e-06, 6.4376e-07, 4.6067e-07, 3.2384e-07, 2.4660e-08, 5.3090e-07,
            5.7447e-08, 5.0411e-07, 2.6994e-07, 1.1136e-06, 2.7019e-06, 1.4653e-07,
            6.5892e-07, 8.6931e-09, 4.0941e-07, 3.5560e-07, 2.4615e-14, 1.3871e-06,
            1.0544e-06, 2.0632e-07]])},
  5: {'step': tensor(19200.),
   'exp_avg': tensor([ 4.1278e-06, -1.7324e-07]),
   'exp_avg_sq': tensor([1.0909e-09, 1.1236e-09])}},
 'param_groups': [{'lr': 0.001,
   'betas': (0.9, 0.999),
   'eps': 1e-08,
   'weight_decay': 0,
   'amsgrad': False,
   'maximize': False,
   'foreach': None,
   'capturable': False,
   'differentiable': False,
   'fused': None,
   'params': [0, 1, 2, 3, 4, 5]}]}
In [119]:
model_data0['optimizer_states'][0]
Out[119]:
{'state': {0: {'step': tensor(24704.),
   'exp_avg': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]),
   'exp_avg_sq': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]])},
  1: {'step': tensor(24704.),
   'exp_avg': tensor([-3.3291e-09, -2.6901e-07, -8.5368e-09, -6.1689e-09, -5.8358e-10,
           -4.2298e-08,  1.2142e-08, -8.4066e-08, -1.1084e-08,  6.6887e-08,
            1.1023e-08, -3.7201e-08, -7.6836e-09,  2.2619e-08, -5.2994e-09,
           -6.2287e-08, -9.4547e-10, -1.4335e-09,  3.8968e-08,  2.2513e-08,
            2.2364e-09, -1.1143e-09,  3.6313e-08, -4.9238e-08, -9.3458e-09,
            6.9197e-08,  1.5287e-09,  8.8791e-09, -4.9362e-08, -1.3478e-09,
           -3.2514e-08,  2.3516e-09,  3.7128e-09,  8.7992e-08, -7.9905e-08,
           -9.8589e-08, -1.2815e-08,  2.1365e-08,  1.5157e-08, -4.6912e-09,
           -2.3465e-08, -1.3516e-08,  3.5132e-08, -7.1798e-09,  2.0962e-08,
            2.3280e-08, -1.4391e-09, -3.2664e-09,  1.0576e-08, -2.7358e-08,
            3.9048e-08,  4.3482e-08, -8.9489e-10,  9.2530e-09,  1.1110e-07,
            9.3799e-09, -1.1647e-07,  4.5923e-10,  2.5281e-08, -4.9357e-10,
            2.8122e-09, -1.6036e-08, -5.4464e-08,  3.9694e-09, -1.2456e-08,
           -3.0809e-09,  6.5620e-08, -2.5560e-09,  2.5220e-09, -8.9181e-10,
           -1.0684e-08,  1.1933e-12, -1.4719e-08, -3.6113e-08, -4.9516e-08,
            1.0140e-08,  2.9423e-08,  1.9112e-08,  2.4743e-08, -3.2870e-09,
            7.1025e-09, -2.6974e-10, -3.5040e-08, -1.4150e-08,  9.0055e-09,
            4.8752e-08,  2.4106e-08, -6.5053e-09,  8.2710e-09,  4.9359e-09,
           -3.2327e-08,  1.4464e-08, -1.0219e-10,  4.3338e-08,  1.3530e-09,
            2.3589e-08,  6.5122e-08, -7.3957e-09,  1.0636e-08,  7.5874e-08,
            6.5495e-08]),
   'exp_avg_sq': tensor([3.3493e-14, 4.5770e-13, 2.7220e-15, 8.5458e-15, 8.8367e-14, 1.5491e-13,
           1.1714e-14, 2.7347e-13, 5.3222e-15, 1.6002e-13, 3.1381e-15, 1.0114e-13,
           7.7224e-15, 8.3938e-15, 7.0083e-14, 1.1019e-13, 1.1026e-14, 4.1827e-16,
           1.5740e-13, 1.6618e-13, 5.8237e-14, 6.8468e-15, 3.6846e-14, 1.2065e-13,
           2.5912e-14, 1.2317e-13, 1.8945e-15, 5.6746e-14, 1.4344e-13, 5.8940e-15,
           7.5823e-14, 5.9772e-16, 1.3587e-13, 1.8097e-13, 8.6020e-14, 8.1786e-14,
           5.7440e-14, 4.1861e-14, 2.8075e-15, 6.4289e-16, 3.3174e-14, 8.5489e-15,
           4.2327e-14, 3.2144e-15, 1.7530e-13, 6.0951e-14, 7.1471e-15, 1.7221e-15,
           3.5197e-13, 8.2597e-14, 1.0789e-13, 3.9281e-14, 1.4633e-15, 1.6018e-14,
           2.5271e-13, 3.3406e-14, 2.1086e-13, 1.2904e-14, 8.8726e-14, 1.6301e-16,
           1.5266e-13, 7.2683e-15, 3.6584e-14, 3.2395e-15, 2.5544e-15, 1.0577e-13,
           2.3522e-13, 3.2286e-14, 1.4691e-14, 5.5773e-16, 2.3658e-14, 5.9801e-15,
           8.8853e-15, 3.0933e-14, 5.9003e-14, 4.8041e-14, 1.8997e-14, 3.3076e-14,
           3.7029e-14, 1.7951e-15, 1.5096e-14, 9.8980e-14, 3.4488e-13, 3.3717e-15,
           7.9273e-14, 6.2698e-14, 3.8048e-14, 9.2425e-16, 1.9873e-14, 3.6993e-15,
           5.3648e-14, 8.3630e-14, 4.9177e-15, 1.7690e-14, 7.0701e-14, 9.8439e-14,
           8.8344e-14, 7.4604e-14, 3.3676e-14, 8.5277e-14, 2.5562e-13])},
  2: {'step': tensor(24704.),
   'exp_avg': tensor([[-4.9326e-08,  1.1744e-06,  4.6956e-07,  ...,  6.5900e-07,
            -1.2494e-06, -9.6769e-06],
           [-1.0783e-37,  1.1338e-37,  1.0727e-37,  ...,  1.1544e-37,
             1.1560e-37,  1.1439e-37],
           [-4.5249e-08,  5.8841e-06,  2.7241e-06,  ...,  3.0290e-06,
             2.4178e-07,  5.7643e-06],
           ...,
           [ 3.2185e-06,  9.8147e-07, -8.6194e-08,  ...,  1.5874e-06,
            -1.3662e-06,  1.3513e-06],
           [-1.6921e-06, -6.0535e-06, -4.8379e-07,  ..., -9.9990e-07,
            -1.9927e-08,  1.5375e-09],
           [-4.0679e-10,  9.8492e-08,  7.0411e-10,  ..., -8.5458e-07,
            -6.8398e-07, -4.3051e-06]]),
   'exp_avg_sq': tensor([[2.0310e-13, 4.3699e-12, 3.3615e-12,  ..., 1.2814e-12, 1.6978e-09,
            4.6638e-09],
           [1.2633e-17, 3.5333e-15, 7.5548e-17,  ..., 7.1894e-17, 6.3369e-18,
            6.6643e-17],
           [6.1298e-13, 4.1991e-09, 2.4885e-10,  ..., 3.1910e-09, 8.7600e-11,
            1.6563e-09],
           ...,
           [2.5256e-10, 1.0209e-09, 1.4628e-11,  ..., 2.9486e-10, 2.7519e-11,
            6.9032e-10],
           [1.6946e-10, 2.7459e-09, 1.0914e-10,  ..., 2.5188e-11, 4.4105e-13,
            2.9463e-12],
           [8.0220e-14, 1.6896e-12, 9.2324e-13,  ..., 1.1071e-10, 5.8465e-10,
            2.4667e-10]])},
  3: {'step': tensor(24704.),
   'exp_avg': tensor([-1.7259e-07,  1.1149e-37,  5.9648e-07, -1.6012e-07, -4.0144e-08,
           -1.4859e-07,  3.2276e-09,  7.5898e-08, -9.3215e-09, -1.5498e-07,
            1.5582e-07, -9.4949e-08, -5.6716e-08,  8.2442e-08,  1.6503e-07,
           -4.5831e-08, -1.3578e-07,  7.7463e-08, -9.9964e-08, -7.8850e-08]),
   'exp_avg_sq': tensor([2.3174e-12, 1.1540e-15, 8.4657e-12, 1.0602e-12, 8.4984e-13, 3.4273e-12,
           1.1004e-13, 1.8538e-12, 1.0229e-12, 4.6470e-12, 1.4870e-12, 7.9469e-13,
           1.0878e-12, 1.0385e-12, 6.9255e-13, 6.6388e-13, 8.0358e-13, 9.1733e-13,
           4.7585e-13, 3.2537e-13])},
  4: {'step': tensor(24704.),
   'exp_avg': tensor([[ 1.3879e-05, -1.0978e-37,  3.1283e-06, -3.2041e-05, -8.3746e-05,
            -5.4410e-05, -2.0653e-06, -6.0644e-05,  1.2586e-05,  5.4365e-05,
            -2.8233e-05, -2.4750e-05, -3.0203e-05, -7.5921e-05,  1.1861e-05,
            -1.0003e-05,  5.6459e-06,  2.1341e-05, -1.3047e-05, -8.5711e-05],
           [-1.9475e-04, -1.0633e-37, -1.5675e-04,  4.5946e-05, -5.8112e-05,
             8.7905e-06,  4.8987e-05, -1.8739e-05,  5.5974e-06,  2.7959e-05,
            -1.6353e-05, -2.4068e-05,  9.0900e-06, -2.8180e-05,  2.1717e-05,
             3.6198e-05, -4.8609e-05,  4.2364e-05,  7.9464e-05, -3.7461e-05]]),
   'exp_avg_sq': tensor([[5.7826e-08, 8.0786e-16, 8.8738e-07, 8.2365e-08, 2.7641e-07, 2.6401e-07,
            6.2966e-09, 5.2752e-07, 5.1849e-07, 1.7811e-06, 7.1370e-07, 3.9264e-07,
            1.1426e-07, 4.0743e-07, 2.1667e-07, 3.9353e-08, 3.2557e-09, 1.7091e-07,
            3.6437e-08, 2.6345e-07],
           [1.0502e-06, 7.0775e-16, 1.4466e-06, 2.3260e-07, 1.5807e-07, 1.5568e-07,
            7.9575e-08, 1.5304e-07, 1.7943e-07, 2.6096e-06, 1.5657e-07, 2.4886e-08,
            1.2239e-07, 3.9524e-08, 3.6080e-07, 1.9693e-07, 4.9994e-08, 2.4139e-07,
            3.6219e-07, 3.9706e-08]])},
  5: {'step': tensor(24704.),
   'exp_avg': tensor([-6.4559e-07, -1.2125e-06]),
   'exp_avg_sq': tensor([2.0056e-10, 2.9153e-10])}},
 'param_groups': [{'lr': 0.001,
   'betas': (0.9, 0.999),
   'eps': 1e-08,
   'weight_decay': 0,
   'amsgrad': False,
   'maximize': False,
   'foreach': None,
   'capturable': False,
   'differentiable': False,
   'fused': None,
   'params': [0, 1, 2, 3, 4, 5]}]}
In [117]:
model_data_init['optimizer_states'][0]['param_groups']
Out[117]:
[{'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'differentiable': False,
  'fused': None,
  'params': [0, 1, 2, 3, 4, 5]}]

It's hard to get much out of these numbers. The basic parameters aren't changing much.
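Rather than dumping the raw optimizer tensors, the Adam moments could be summarized per parameter tensor. A sketch of the kind of helper that would make these checkpoints comparable; `summarize_adam_state` is made up, and a toy model stands in for the loaded checkpoints:

```python
import torch

def summarize_adam_state(opt_state):
    """Norms of Adam's first/second moments per parameter tensor,
    as a readable alternative to printing the raw state dict."""
    rows = []
    for idx, st in opt_state['state'].items():
        rows.append({
            'param': idx,
            'step': float(st['step']),
            '|exp_avg|': st['exp_avg'].norm().item(),
            '|exp_avg_sq|': st['exp_avg_sq'].norm().item(),
        })
    return rows

# toy model + one optimizer step, just to produce a state dict
# shaped like the checkpoints loaded above
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
opt.zero_grad()
model(torch.randn(8, 4)).pow(2).mean().backward()
opt.step()

for row in summarize_adam_state(opt.state_dict()):
    print(row)
```

Applied to `model_data0['optimizer_states'][0]` and friends, this would reduce each checkpoint to a handful of norms that can be compared across iterations directly.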

In [ ]:
 

Are the results different if we save only the weights during pretraining?¶

Experiment: initialize models, pretrain saving either weights only or not, compare.

In [6]:
init_data = pd.read_pickle('experiment_result/ex6_init2.pickle')
In [22]:
init_data.sort_values(['method', 'rep'])
Out[22]:
rep method Fisher
8 0 all [9731.43904304492, 8632.25088476819, 9236.8925...
9 0 all [9154.476091599896, 9358.899765289718, 7723.10...
10 0 all [7900.043777763896, 6880.582472615025, 7635.87...
11 0 all [7668.709673804187, 8520.00988880094, 9665.104...
12 0 all [8073.243401894655, 7885.193828813932, 8202.21...
13 0 all [9092.411254991597, 9974.333940118238, 8498.38...
22 1 all [4716.853426608467, 5987.776276603933, 4814.12...
23 1 all [5251.742508812979, 5046.3991082111115, 5478.3...
24 1 all [6491.01862820697, 5418.43973423646, 5068.2023...
25 1 all [5574.396421648473, 6090.717359568154, 5532.34...
26 1 all [6283.682062739737, 5033.9707862196365, 5257.4...
27 1 all [5736.774967734029, 5660.824241487151, 5224.84...
36 2 all [7245.134008848942, 7376.052201136431, 7598.22...
37 2 all [5821.538205425657, 5957.800981747067, 6563.98...
38 2 all [6823.262177726626, 7335.203080726032, 6860.93...
39 2 all [7724.908343458281, 6320.398675279556, 8219.12...
40 2 all [8775.635508300766, 8271.719569157684, 8201.03...
41 2 all [7256.410774290576, 7065.2717751842865, 7735.6...
0 0 init [94.97747642708468, 109.0764109552019, 128.875...
7 0 init [8.9423179738955, 9.557862311460575, 11.175286...
14 1 init [63.87981372499806, 64.30695170332146, 64.7981...
21 1 init [67.886799586104, 70.30323071279315, 68.704659...
28 2 init [0.34414571095275365, 0.07857253573667755, 0.1...
35 2 init [11.015241842309006, 15.101778242178428, 12.58...
1 0 weights [7489.937823357886, 7298.444890711764, 6808.25...
2 0 weights [6866.899285417158, 6418.306668619433, 7346.95...
3 0 weights [6402.598569271286, 6265.5348798878495, 6175.5...
4 0 weights [9434.350410260768, 7725.35230058199, 8216.120...
5 0 weights [7756.657083003193, 6857.583918873309, 8427.94...
6 0 weights [8042.877249508951, 7283.056005857774, 7762.27...
15 1 weights [8534.682269248882, 9713.9064469059, 7807.6463...
16 1 weights [9984.304641736942, 7825.095072366619, 8354.00...
17 1 weights [11365.563566082628, 11240.831909742921, 9315....
18 1 weights [8846.371134361654, 9038.215308570376, 8712.06...
19 1 weights [9737.283185256647, 8907.589827069598, 9443.63...
20 1 weights [9547.901531240255, 10699.910357082279, 9155.6...
29 2 weights [7360.646468316156, 7404.83808889236, 6220.085...
30 2 weights [6257.006131289272, 7131.237135011165, 6421.76...
31 2 weights [6734.224614825596, 5947.305354442145, 6991.21...
32 2 weights [7582.8126041177775, 6751.791311037392, 7093.8...
33 2 weights [6073.9006480561175, 6102.018964492787, 6490.7...
34 2 weights [5292.418137873429, 6522.818005062166, 6758.61...
In [64]:
plt.title('Cov - All, init, weights')
sns.heatmap(np.corrcoef( np.array(init_data.sort_values(['method', 'rep'])['Fisher'].to_list() )))
Out[64]:
<Axes: title={'center': 'Cov - All, init, weights'}>
No description has been provided for this image
In [57]:
initial_runs = np.array(init_data[ init_data.method == 'init']['Fisher'].to_list())

figsize(4,3.2)
plt.title('initial runs - removing mean')
sns.heatmap(np.corrcoef( initial_runs - initial_runs.mean(0) ))
Out[57]:
<Axes: title={'center': 'initial runs - removing mean'}>
No description has been provided for this image
In [41]:
colors = plt.cm.viridis(np.linspace(0,1,3))

figsize(20, 20)
plt.subplots(3,1)

plt.subplot(3,1,1)
plt.title('Initialization')
for row in init_data[ init_data.method == 'init'].to_dict(orient="records"):
    plt.plot(row['Fisher'], c=colors[row['rep']])

plt.subplot(3,1,2)
plt.title('All')
for row in init_data[ init_data.method == 'all'].to_dict(orient="records"):
    plt.plot(row['Fisher'], c=colors[row['rep']])

plt.subplot(3,1,3)
plt.title('weights')
for row in init_data[ init_data.method == 'weights'].to_dict(orient="records"):
    plt.plot(row['Fisher'], c=colors[row['rep']])
No description has been provided for this image

Ok. These results are very interesting¶

  1. Fisher curves with the same initialization are correlated regardless of how the model is saved.

    • this can be seen in the curves themselves, which are more similar within colors than between them
    • it is also shown in the correlation plot by the square blocks along the diagonal, 6 entries each.
  2. The correlation in the initialization curves is due to the non-uniform distribution: removing the mean removes the correlation for the most part.

    • However, it does introduce negative correlation, presumably because the mean is computed from these trajectories themselves.
  3. The weight-only curves are very correlated across initializations, and very correlated to the Fisher information in initialization runs.

    • This holds regardless of whether the initialization run corresponded to the particular weight-only curve in question, which suggests that it results from a failure to remove the initial Fisher information trend during the weight-only training.
In [ ]:
 
In [134]:
copy = init_data[init_data.method == 'init'].copy()  # .copy() avoids SettingWithCopyWarning

mean = copy['Fisher'].mean()
copy.loc[:, 'Fisher'] = copy.apply(lambda r: r['Fisher'] - mean, axis=1)
copy.loc[:, 'method'] = 'init_mean_rm'
In [140]:
init_data_aug = pd.concat([init_data, copy])
plt.title('Cov - including mean removal')
sns.heatmap(np.corrcoef( np.array(init_data_aug.sort_values(['method', 'rep'])['Fisher'].to_list() )))
Out[140]:
<Axes: title={'center': 'Cov - including mean removal'}>
No description has been provided for this image

Introducing mean removal¶

  • again, there is a sizable degree of negative correlation

  • The mean-removed versions seem to show little correlation to either the weight-only or the all-parameter models.

  • However, they do pick out quite well the initialization run that each of the retraining results is based on.

Initialization conclusions:¶

There is, in fact, a sizable impact of initialization on the networks that are learned. This holds regardless of how the network is saved and reloaded. Saving network and training parameters beyond the weights seems to make the retraining more effective at removing general trends from the network, but the other initialization effects remain.

On one hand, this is exactly the type of effect that I was hoping to see: there is a residual of the pre-training that can be detected. It is very interesting that saving the state of the trainer is sufficient to remove it.

On the other hand, there remains initialization dependent correlation. Perhaps this is not altogether that surprising: the network doesn't need to unlearn quirks of the initialization if these don't impact generalization ability of the network. In a network that generalizes very well after the initial training, we wouldn't expect the second round of training to do much, if anything, with data from the same distribution, freshly sampled.

In [ ]: